# Exemple d'optimalité entre les 3 fonctions MH, MH_high_dim et MH_high_dim_para sur le modèle mixte non linéaire
# Logistic curve evaluated at time(s) t: tends to phi1 as t grows, equals
# phi1 / 2 at t = phi2, and phi3 controls how steep the transition is.
# Vectorised over t (and over the phi arguments).
m <- function(t, phi1, phi2, phi3) {
  denominator <- 1 + exp((phi2 - t) / phi3)
  phi1 / denominator
}
#=======================================#
# Length of the regression coefficient vector 'beta'.
p <- 100
# True parameter values used to simulate the data.
parameter <- list(
  sigma2   = .05^2,           # variance of the Gaussian noise added to Y
  # rho2 = 5,
  mu       = c(0.9, 90, 5),   # prior means of phi1, phi2, phi3
  omega2   = c(0.005, 40, 1), # prior variances of phi1, phi2, phi3
  # S.data data,
  bara     = 90,
  barb     = 30,              # prior mean of the latent variable b
  baralpha = 0.5,             # prior mean of the latent variable alpha
  # Sparse truth: only the first four coefficients are non-zero.
  beta     = replace(rep(0, p), 1:4, c(-.8, -.2, .3, .9))
)
#=======================================#
# Observation times, shared by every individual.
t <- seq(60, 120, length.out = 10)
set.seed(123)

# G groups ("gen" in the plot below) of ng individuals each.
G <- 40
ng <- 4
# Sizes used below and by the model definition (rnorm, data.frame, SPGD);
# defined explicitly so the script does not rely on hidden globals.
N <- G * ng          # number of individuals
n <- N * length(t)   # total number of longitudinal observations

# link = function(t,phi1, phi2, phi3) phi2#/phi3
link <- m

dt <- create_JLS_HD_data(G, ng, t, m, link, parameter)

var.true <- dt$var.true
# 'a' is kept fixed (and removed from the latent variables).
a <- var.true$a ; var.true$a <- NULL
S.data <- dt$survival
U <- dt$U

# Noisy longitudinal observations: prediction at the true latent values plus
# Gaussian noise with variance sigma2.
Y <- do.call(get_obs, var.true) + rnorm(n, 0, sqrt(parameter$sigma2))
S.data.time <- S.data$obs
S.data.time.log.sum <- sum(log(S.data.time))

# Rows are assumed ordered individual by individual (as the 'id' column
# encodes), each individual observed at every time in 't'.
longitudinal_plot <- data.frame(time = rep(t, times = N), Y,
                                id = rep(1:N, each = length(t)),
                                gen = rep(1:G, each = ng * length(t))) %>%
  ggplot(aes(time, Y, col = factor(gen), group = factor(id))) +
  geom_point() + geom_line() +
  theme(legend.position = 'none')

# Histogram of the observed survival times (previously also tried: fill = U).
S.data_plot <- S.data %>% ggplot(aes(obs)) +
  geom_histogram(col = 'white', position = 'identity', bins = 30) +
  theme(legend.position = 'none')

grid.arrange(longitudinal_plot, S.data_plot, nrow = 2)

# Specification of the SAEM model fitted below.
# The two leading anonymous functions are paired with the parameter name
# 'sigma2'; the second one returns the mean squared residual between Y and
# get_obs(phi1, phi2, phi3). NOTE(review): confirm the exact role of each
# function against the SAEM_model documentation.
# Y, n and N are free variables, looked up when these functions are evaluated.
model <- SAEM_model(
  function(sigma2, ...) -n / (2 * sigma2),
  function(phi1, phi2, phi3, ...) mean((Y - get_obs(phi1, phi2, phi3))^2),
  'sigma2',

  # === Latent variables === #
  latent_vars = list(
    # === Nonlinear longitudinal model === #
    # 'phi' has dim = G and size = 3 (presumably phi1, phi2, phi3), with prior
    # mean 'mu' and prior variance 'omega2'.
    latent_variable('phi', dim = G, size = 3,
                    prior = list(mean = 'mu', variance = 'omega2'),
                    add_on = c('zeta(phi1 = phi1, phi2 = phi2, phi3 = phi3, ...)')),

    # === Survival (S.data) model === #
    # The add_on entries are code fragments passed as strings; keep verbatim.
    latent_variable('b', prior = list(mean = 'barb', variance.hyper = 'sigma2_b'),
                    add_on = c('zeta(b = b, ...) +',
                               'sum(h$eval(b = b, ..., i = c(1,2)))')),
    latent_variable('alpha', prior = list(mean = 'baralpha', variance.hyper = 'sigma2_alpha'),
                    add_on = c('zeta(alpha = alpha, ...) +',
                               'alpha*h$eval(alpha = alpha,..., i = 3)'))
  ),

  # === Regression parameter === #
  # 'beta' is updated through SPGD with step 0.05 and lambda = 1/sqrt(N).
  # NOTE(review): SPGD looks like a penalised gradient step; confirm.
  regression.parameter = list(
    regression_parameter('beta', 1, function(...) SPGD(1, theta0 = beta,
                                                      step = 0.05,
                                                      lambda = 1 / sqrt(N),
                                                      normalized.grad = TRUE,
                                                      zeta.der.B, N, zeta.B,
                                                      Z$alpha, Z$phi1, Z$phi2, Z$phi3, Z$b))
  )
)
# --- Parameter initialisation --- #
# Each true parameter is multiplied by one random factor drawn in [1.1, 1.4]
# (one draw per list element). lapply always returns a named list, whereas
# sapply would silently simplify to a matrix if every element had the same
# length, which would break the `parameter0$beta <-` assignment below.
parameter0 <- lapply(parameter, function(x) x * runif(1, 1.1, 1.4))
# beta is re-initialised uniformly in [-1, 1] instead.
parameter0$beta <- runif(p, min = -1, max = 1)

#===============================================#
# Load the model; the following lines use S$eval and maximisation, which are
# presumably brought into scope by this call.
load.SAEM(model)
# Statistics evaluated at the true latent variables (computed once).
S.tmp <- do.call(S$eval, var.true)
# Oracle: one maximisation step from the true latent variables. It is used as
# the reference (`true.value`) in the result plots below.
oracle <- maximisation(1, S.tmp, parameter, var.true)
#==============================================================================#

# Starting values (x0) and initial spreads (sd) for each latent variable;
# presumably the sd entries are the initial MCMC proposal standard deviations.
init.options <- list(
  x0 = list(phi = c(1, 80, 4),
            b = parameter0$barb,
            alpha = parameter0$baralpha),
  sd = list(phi = c(.05, 1.5, .5),
            b = 1,
            alpha = .1)
)

# SAEM run settings. The option name 'adptative.sd' is kept exactly as written
# (presumably the name run() reads).
SAEM.options <- list(
  niter = 200,
  sim.iter = 5,
  burnin = 190,
  adptative.sd = 0.6
)

# === SAEM run 1: all parameters estimated === #

# First run: no load.options, so no parameter is excluded from maximisation.
res <- run(model, parameter0, init.options, SAEM.options,
           verbatim = 3)
saveRDS(res,
        file = paste0(params$rds_filename, '_', p, '.rds'))

# = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = #

# Diagnostics for run 1, with the oracle as the reference value.
plot(res, true.value = oracle, exclude = 'beta')
## [1] "SAEM execution time = 00h 38min 33sec"
## $plot_parameter

## 
## $plot_MCMC

## 
## $plot_acceptation

plot_high_dim(res, oracle, 'beta', zeta, dec = 0,
              var.true$alpha, var.true$phi1, var.true$phi2, var.true$phi3, var.true$b)
## [[1]]

## 
## [[2]]

## 
## [[3]]

plot(res, true.value = oracle, var = 'summary', exclude = 'beta', time = FALSE)
## Result of the SAEM-MCMC
## sigma2 mu.1 mu.2 mu.3 omega2.1 omega2.2 omega2.3 barb baralpha
## Real value 0.0026 0.9032 89.9575 5.0079 0.0039 35.9179 0.6948 30.0000 0.5000
## Estimated value 0.0026 0.9003 89.9695 4.9188 0.0045 35.1034 0.6817 37.0008 2.2036
## Rrmse 0.0036 0.0032 0.0001 0.0178 0.1333 0.0227 0.0188 0.2334 3.4072

# === SAEM run 2: 'baralpha' held fixed at its true value === #

# Second run: 'baralpha' is excluded from maximisation and its starting value
# is set to the true value.
load.options <- list(exclude.maximisation = c('baralpha'))
parameter0[load.options$exclude.maximisation] <- parameter[load.options$exclude.maximisation]

res <- run(model, parameter0, init.options, SAEM.options, load.options, verbatim = 3)
# Run-specific file name: reusing '<rds_filename>_<p>.rds' would overwrite the
# result saved by the first run.
saveRDS(res, paste0(params$rds_filename, '_', p, '_fixed_baralpha.rds'))

# = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = #

# Diagnostics for run 2, with the oracle as the reference value.
plot(res, true.value = oracle, exclude = 'beta')
## [1] "SAEM execution time = 00h 38min 44sec"
## $plot_parameter

## 
## $plot_MCMC

## 
## $plot_acceptation

plot_high_dim(res, oracle, 'beta', zeta, dec = 0,
              var.true$alpha, var.true$phi1, var.true$phi2, var.true$phi3, var.true$b)
## [[1]]

## 
## [[2]]

## 
## [[3]]

plot(res, true.value = oracle, var = 'summary', exclude = 'beta', time = FALSE)
## Result of the SAEM-MCMC
## sigma2 mu.1 mu.2 mu.3 omega2.1 omega2.2 omega2.3 barb baralpha
## Real value 0.0026 0.9032 89.9575 5.0079 0.0039 35.9179 0.6948 30.0000 0.5
## Estimated value 0.0026 0.9016 89.9684 5.0528 0.0040 35.6588 0.8133 38.4418 0.5
## Rrmse 0.0041 0.0018 0.0001 0.0090 0.0089 0.0072 0.1707 0.2814 0.0

# === SAEM run 3: 'baralpha' and 'barb' held fixed at their true values === #

# Third run: both 'baralpha' and 'barb' are excluded from maximisation and
# their starting values are set to the true values.
load.options <- list(exclude.maximisation = c('baralpha', 'barb'))
parameter0[load.options$exclude.maximisation] <- parameter[load.options$exclude.maximisation]

res <- run(model, parameter0, init.options, SAEM.options, load.options, verbatim = 3)
# Run-specific file name: reusing '<rds_filename>_<p>.rds' would overwrite the
# result saved by the first run.
saveRDS(res, paste0(params$rds_filename, '_', p, '_fixed_baralpha_barb.rds'))

# = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = #

# Diagnostics for run 3, with the oracle as the reference value.
plot(res, true.value = oracle, exclude = 'beta')
## [1] "SAEM execution time = 00h 35min 02sec"
## $plot_parameter

## 
## $plot_MCMC

## 
## $plot_acceptation

plot_high_dim(res, oracle, 'beta', zeta, dec = 0,
              var.true$alpha, var.true$phi1, var.true$phi2, var.true$phi3, var.true$b)
## [[1]]

## 
## [[2]]

## 
## [[3]]

plot(res, true.value = oracle, var = 'summary', exclude = 'beta', time = FALSE)
## Result of the SAEM-MCMC
## sigma2 mu.1 mu.2 mu.3 omega2.1 omega2.2 omega2.3 baralpha barb
## Real value 0.0026 0.9032 89.9575 5.0079 0.0039 35.9179 0.6948 0.5 30
## Estimated value 0.0025 0.9021 89.9561 4.9671 0.0044 35.4356 0.7924 0.5 30
## Rrmse 0.0056 0.0012 0.0000 0.0081 0.1145 0.0134 0.1406 0.0 0